package validator

import (
	"errors"
	"github.com/modern-go/reflect2"
	"github.com/valyala/fasthttp"
	"math"
	"reflect"
	"strconv"
	"strings"
	"unicode/utf8"
)

func Bind(ctx *fasthttp.RequestCtx, objs interface{}) error {
	// Bind populates the struct pointed to by objs from the request arguments
	// of ctx (query string for GET, form body for POST, whitespace-trimmed),
	// validating each field according to its struct tags:
	//
	//	name     request argument name; defaults to the lower-cased field name.
	//	rule     validation rule (digit, digitString, sDigit, sAlpha, url, alnum,
	//	         priv, dateTime, date, time, chn, module, float, vnphone, filter,
	//	         uname, upwd). Unknown rules accept any value; "none" also
	//	         suppresses the "not found" error for a missing argument.
	//	min, max integer bounds passed to the rule (value range for "digit",
	//	         length bounds for the string rules); default 0 and MaxInt64.
	//	msg      error message returned when the rule rejects the value.
	//	required "0" makes the argument optional; anything else (including an
	//	         absent tag) makes it required.
	//	default  value used when the argument is missing or blank.
	//
	// Defaults and missing values are never run through the rule. Values that
	// cannot be parsed into the field's type leave the field unchanged; only
	// bool, integer, float and string fields are assigned.
	//
	// An error is returned if objs is not a non-nil pointer to a struct, if a
	// required argument is missing ("<name> not found"), or if a rule fails
	// (errors.New(msg)).
	if objs == nil {
		return errors.New("argument 2 should be map or ptr")
	}

	rt := reflect2.TypeOf(objs)
	if rt.Kind() != reflect.Ptr {
		return errors.New("argument 2 should be map or ptr")
	}
	// A typed nil pointer passes the kind check but UnsafeSet would write through it.
	if reflect.ValueOf(objs).IsNil() {
		return errors.New("argument 2 should be a non-nil ptr")
	}

	rtElem := rt.(reflect2.PtrType).Elem()
	if rtElem.Kind() != reflect.Struct {
		return errors.New("non-structure type not supported yet")
	}
	s := rtElem.(reflect2.StructType)

	// The request method does not change while binding; read it once.
	isGet, isPost := ctx.IsGet(), ctx.IsPost()

	for i := 0; i < s.NumField(); i++ {
		f := s.Field(i)
		tag := f.Tag()

		name := tag.Get("name")
		if name == "" {
			name = strings.ToLower(f.Name())
		}
		rule := tag.Get("rule")

		minVal := int64(0)
		if v, err := strconv.ParseInt(tag.Get("min"), 10, 64); err == nil {
			minVal = v
		}
		maxVal := int64(math.MaxInt64)
		if v, err := strconv.ParseInt(tag.Get("max"), 10, 64); err == nil {
			maxVal = v
		}

		val := ""
		if isGet {
			val = strings.TrimSpace(string(ctx.QueryArgs().Peek(name)))
		} else if isPost {
			val = strings.TrimSpace(string(ctx.PostArgs().Peek(name)))
		}

		if val == "" {
			// Missing or blank: fall back to the default, which is not validated.
			val = tag.Get("default")
			// Required parameter with no default value.
			if val == "" && tag.Get("required") != "0" && rule != "none" {
				return errors.New(name + " not found")
			}
		} else {
			// Only values actually supplied by the client are validated.
			checked, err := bindCheckRule(rule, val, minVal, maxVal, tag.Get("msg"))
			if err != nil {
				return err
			}
			val = checked
		}

		bindSetField(f, objs, val)
	}

	return nil
}

// bindCheckRule validates val against the named Bind rule using the field's
// min/max tag bounds. It returns the value to store (rewritten by
// FilterInjection for the "filter" rule) or errors.New(msg) if val is rejected.
// Unknown rules accept any value.
func bindCheckRule(rule, val string, minVal, maxVal int64, msg string) (string, error) {
	// Length-based checks take int; saturate so MaxInt64 does not wrap on 32-bit.
	lo, hi := bindClampInt(minVal), bindClampInt(maxVal)

	ok := true
	switch rule {
	case "digit":
		ok = CheckStringDigit(val) && CheckIntScope(val, minVal, maxVal)
	case "digitString":
		ok = CheckStringDigit(val) && CheckStringLength(val, lo, hi)
	case "sDigit":
		ok = CheckStringCommaDigit(val) && CheckStringLength(val, lo, hi)
	case "sAlpha":
		ok = CheckStringCommaAlpha(val) && CheckStringLength(val, lo, hi)
	case "url":
		ok = CheckUrl(val)
	case "alnum":
		ok = CheckStringAlnum(val) && CheckStringLength(val, lo, hi)
	case "priv":
		ok = isPriv(val)
	case "dateTime":
		ok = CheckDateTime(val)
	case "date":
		ok = CheckDate(val)
	case "time":
		ok = checkTime(val)
	case "chn":
		ok = CheckStringCHN(val)
	case "module":
		ok = CheckStringModule(val) && CheckStringLength(val, lo, hi)
	case "float":
		ok = CheckFloat(val)
	case "vnphone":
		ok = IsVietnamesePhone(val)
	case "filter":
		ok = CheckStringLength(val, lo, hi)
		if ok {
			val = FilterInjection(val)
		}
	case "uname": // member account name
		ok = CheckUName(val, lo, hi)
	case "upwd": // member password
		ok = CheckUPassword(val, lo, hi)
	}

	if !ok {
		return "", errors.New(msg)
	}
	return val, nil
}

// bindClampInt converts v to int, saturating at the platform's int range so
// that large bounds (such as the default math.MaxInt64) do not wrap on 32-bit targets.
func bindClampInt(v int64) int {
	const maxInt = int64(^uint(0) >> 1)
	const minInt = -maxInt - 1
	switch {
	case v > maxInt:
		return int(maxInt)
	case v < minInt:
		return int(minInt)
	}
	return int(v)
}

// bindSetField parses raw according to the kind of field f and stores it into
// the struct pointed to by obj. Every parsed value is converted to the field's
// exact Go type first: UnsafeSet copies field-size bytes from the source
// pointer, so handing it e.g. a float64 for a float32 field stores wrong bits.
// Parse failures leave the field unchanged; unsupported kinds are ignored.
func bindSetField(f reflect2.StructField, obj interface{}, raw string) {
	p := reflect2.PtrOf(obj)
	switch f.Type().Kind() {
	case reflect.Bool:
		if v, err := strconv.ParseBool(raw); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(v))
		}
	case reflect.Int:
		if v, err := strconv.Atoi(raw); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(v))
		}
	case reflect.Int8:
		if v, err := strconv.ParseInt(raw, 10, 8); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(int8(v)))
		}
	case reflect.Int16:
		if v, err := strconv.ParseInt(raw, 10, 16); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(int16(v)))
		}
	case reflect.Int32:
		if v, err := strconv.ParseInt(raw, 10, 32); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(int32(v)))
		}
	case reflect.Int64:
		if v, err := strconv.ParseInt(raw, 10, 64); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(v))
		}
	case reflect.Uint:
		// bitSize 0 range-checks against the platform's uint size.
		if v, err := strconv.ParseUint(raw, 10, 0); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(uint(v)))
		}
	case reflect.Uint8:
		if v, err := strconv.ParseUint(raw, 10, 8); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(uint8(v)))
		}
	case reflect.Uint16:
		if v, err := strconv.ParseUint(raw, 10, 16); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(uint16(v)))
		}
	case reflect.Uint32:
		if v, err := strconv.ParseUint(raw, 10, 32); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(uint32(v)))
		}
	case reflect.Uint64:
		if v, err := strconv.ParseUint(raw, 10, 64); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(v))
		}
	case reflect.Uintptr:
		if v, err := strconv.ParseUint(raw, 10, 64); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(uintptr(v)))
		}
	case reflect.Float32:
		// ParseFloat always returns float64; with bitSize 32 it is exactly representable as float32.
		if v, err := strconv.ParseFloat(raw, 32); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(float32(v)))
		}
	case reflect.Float64:
		if v, err := strconv.ParseFloat(raw, 64); err == nil {
			f.UnsafeSet(p, reflect2.PtrOf(v))
		}
	case reflect.String:
		f.UnsafeSet(p, reflect2.PtrOf(raw))
	}
}

type cb func(content string) bool

var verify_cb = map[string]cb{
	"uname":      cTypeUName,
	"alnumHan":   CheckAlnumHan,
	"unames":     cTypeUNames,
	"upwd":       cTypeUPassword,
	"seamo":      cTypeSeamo,
	"digit":      CtypeDigit,
	"alnum":      CtypeAlnum,
	"url":        CheckUrl,
	"dateTime":   CheckDateTime,
	"commaDigit": CheckStringCommaDigit,
}

// BindArgs reads every request argument named by a key of rules (query string
// for GET, form body for POST; whitespace-trimmed), validates it against the
// key's rule string and returns the collected values keyed by argument name.
//
// A rule of "-" accepts the value without any validation. Any other rule is
// parsed as a query-string style option list, for example
// "ty=string&lte=2&gte=20&fn=alnum&required=1&msg=bad+name":
//
//	ty       "string" bounds the rune count; "int" requires an integer and bounds its value.
//	lte      lower bound: length/value must be >= lte (despite the name).
//	gte      upper bound, only when > 0: length/value must be <= gte (despite the name).
//	fn       name of a validator in verify_cb; unknown names are ignored.
//	msg      text of the error returned on any failure.
//	required when false, an empty value skips every check.
//
// Rules are processed in map iteration order, which is unspecified. On the
// first failure the partially filled map is returned together with errors.New(msg).
func BindArgs(fctx *fasthttp.RequestCtx, rules map[string]string) (map[string]string, error) {
	bound := map[string]string{}
	var opts fasthttp.Args

	for param, spec := range rules {
		var raw string
		switch {
		case fctx.IsGet():
			raw = strings.TrimSpace(string(fctx.QueryArgs().Peek(param)))
		case fctx.IsPost():
			raw = strings.TrimSpace(string(fctx.PostArgs().Peek(param)))
		}

		if spec == "-" {
			bound[param] = raw
			continue
		}

		opts.Parse(spec)
		errMsg := string(opts.Peek("msg"))
		low := opts.GetUintOrZero("lte")
		high := opts.GetUintOrZero("gte")
		runes := utf8.RuneCountInString(raw)

		// An empty optional value is accepted as-is.
		if opts.GetBool("required") || runes != 0 {
			switch string(opts.Peek("ty")) {
			case "string":
				if runes < low || (high > 0 && runes > high) {
					return bound, errors.New(errMsg)
				}
			case "int":
				n, err := strconv.Atoi(raw)
				if err != nil || n < low || (high > 0 && n > high) {
					return bound, errors.New(errMsg)
				}
			}

			if validate, found := verify_cb[string(opts.Peek("fn"))]; found && !validate(raw) {
				return bound, errors.New(errMsg)
			}
		}

		bound[param] = raw
		opts.Reset()
	}

	return bound, nil
}
